package com.wxsm.wechat.core.jpa;

import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;

import javax.persistence.EntityManager;
import javax.persistence.criteria.*;
import java.io.Serializable;
import java.util.List;

/**
 * Repository base implementation that adds criteria-driven {@code findAll} and
 * {@code findPage} queries on top of {@link SimpleJpaRepository}.
 *
 * <p>A query is described by a {@code JpaSelect}; the {@code JpaSelection} it returns
 * supplies optional selections, predicates, group-by expressions and orders. Each part
 * is applied only when it is non-empty, and a {@code null} selection means
 * "no restrictions at all".
 *
 * <p>Created by Yang Huan
 * Date: 2017/3/2
 * Time: 11:34
 *
 * @param <T>  entity type handled by this repository
 * @param <ID> identifier type of the entity
 */
public class MyRepositoryImpl<T, ID extends Serializable> extends SimpleJpaRepository<T, ID> implements MyRepository<T, ID> {

    private final Class<T> domainClass;
    private final EntityManager entityManager;

    /**
     * Creates the repository for the given entity type.
     *
     * @param domainClass   entity class used as the root and result type of every query
     * @param entityManager entity manager used to build and execute the criteria queries
     */
    public MyRepositoryImpl(Class<T> domainClass, EntityManager entityManager) {
        super(domainClass, entityManager);
        this.domainClass = domainClass;
        this.entityManager = entityManager;
    }

    /**
     * Runs the query described by {@code jpaSelect} and returns a single page of results.
     *
     * <p>Two queries are issued: the select query, limited to the window given by
     * {@code pageable}, and a count query that applies only the predicates of the
     * selection to compute the total number of elements.
     *
     * @param jpaSelect callback that builds the selection for the query root; must not be null
     * @param pageable  paging window (offset and page size); must not be null
     * @return the requested page together with the total element count
     * @throws IllegalArgumentException if {@code pageable} is null
     */
    public Page<T> findPage(JpaSelect<T> jpaSelect, Pageable pageable) {
        Assert.notNull(pageable, "分页参数不能为空");
        CriteriaBuilder criteriaBuilder = entityManager.getCriteriaBuilder();
        CriteriaQuery<T> criteriaQuery = criteriaBuilder.createQuery(domainClass);
        Root<T> root = criteriaQuery.from(domainClass);
        JpaSelection jpaSelectInfo = jpaSelect.getSelection(root, criteriaBuilder);
        List<T> resultList = entityManager.createQuery(this.getFindAllQuery(jpaSelectInfo, criteriaQuery)).
                setFirstResult(pageable.getOffset()).
                setMaxResults(pageable.getPageSize()).
                getResultList();
        return new PageImpl<>(resultList, pageable, this.getCount(jpaSelectInfo, criteriaBuilder));
    }

    /**
     * Runs the query described by {@code jpaSelect} and returns every matching row.
     *
     * @param jpaSelect callback that builds the selection for the query root; must not be null
     * @return all results of the query, unpaged
     */
    public List<T> findAll(JpaSelect<T> jpaSelect) {
        CriteriaBuilder criteriaBuilder = entityManager.getCriteriaBuilder();
        CriteriaQuery<T> criteriaQuery = criteriaBuilder.createQuery(domainClass);
        Root<T> root = criteriaQuery.from(domainClass);
        JpaSelection jpaSelectInfo = jpaSelect.getSelection(root, criteriaBuilder);
        return entityManager.createQuery(this.getFindAllQuery(jpaSelectInfo, criteriaQuery)).getResultList();
    }

    /**
     * Applies the non-empty parts of the selection (select list, where, group by, order by)
     * to the given query.
     *
     * @param jpaSelectInfo selection to apply; may be null, in which case the query is
     *                      returned untouched
     * @param criteriaQuery query to configure
     * @return the same {@code criteriaQuery} instance, for chaining
     */
    @SuppressWarnings("unchecked")
    private CriteriaQuery<T> getFindAllQuery(JpaSelection jpaSelectInfo, CriteriaQuery<T> criteriaQuery) {
        if (null != jpaSelectInfo) {
            Selection[] selections = jpaSelectInfo.getSelections();
            List<Predicate> predicates = jpaSelectInfo.getPredicates();
            Expression[] expressions = jpaSelectInfo.getExpressions();
            Order[] orders = jpaSelectInfo.getOrders();
            if (!ObjectUtils.isEmpty(selections)) {
                criteriaQuery.multiselect(selections);
            }
            if (!ObjectUtils.isEmpty(predicates)) {
                criteriaQuery.where(predicates.toArray(new Predicate[0]));
            }
            if (!ObjectUtils.isEmpty(expressions)) {
                criteriaQuery.groupBy(expressions);
            }
            if (!ObjectUtils.isEmpty(orders)) {
                criteriaQuery.orderBy(orders);
            }
        }
        return criteriaQuery;
    }

    /**
     * Counts the rows of the entity that match the predicates of the selection.
     *
     * <p>Only the predicates are applied; selections, group-by expressions and orders of
     * the selection are ignored by the count query.
     *
     * @param jpaSelectInfo   selection whose predicates restrict the count; may be null,
     *                        in which case every row is counted (mirrors the null handling
     *                        of {@code getFindAllQuery})
     * @param criteriaBuilder builder used to create the count query
     * @return number of matching rows
     */
    @SuppressWarnings("unchecked")
    private Long getCount(JpaSelection jpaSelectInfo, CriteriaBuilder criteriaBuilder) {
        CriteriaQuery<Long> countQuery = criteriaBuilder.createQuery(Long.class);
        countQuery.select(criteriaBuilder.count(countQuery.from(domainClass)));
        // A null selection means "no restrictions": the select query tolerates it, so the
        // count query must as well instead of failing with a NullPointerException.
        if (jpaSelectInfo != null) {
            // NOTE(review): these predicates were built against the Root of the select query,
            // while this count query creates its own Root above. Reusing criteria objects
            // across queries is provider-dependent and will likely break for predicates that
            // rely on joins made on the original root — confirm against the JPA provider in use.
            List<Predicate> predicates = jpaSelectInfo.getPredicates();
            if (predicates != null && !predicates.isEmpty()) {
                countQuery.where(predicates.toArray(new Predicate[0]));
            }
        }
        return entityManager.createQuery(countQuery).getSingleResult();
    }

}